
<!doctype html>
<html lang="en" class="no-js">
  <head>
    
      <meta charset="utf-8">
      <meta name="viewport" content="width=device-width,initial-scale=1">
      
      
      
        <link rel="canonical" href="https://pytorch-widedeep.readthedocs.io/pytorch-widedeep/metrics.html">
      
      
        <link rel="prev" href="losses.html">
      
      
        <link rel="next" href="dataloaders.html">
      
      
      <link rel="icon" href="../assets/images/favicon.ico">
      <meta name="generator" content="mkdocs-1.6.1, mkdocs-material-9.5.43">
    
    
      
        <title>Metrics - pytorch_widedeep</title>
      
    
    
      <link rel="stylesheet" href="../assets/stylesheets/main.0253249f.min.css">
      
        
        <link rel="stylesheet" href="../assets/stylesheets/palette.06af60db.min.css">
      
      


    
    
      
    
    
      
        
        
        <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
        <link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Roboto:300,300i,400,400i,700,700i%7CRoboto+Mono:400,400i,700,700i&display=fallback">
        <style>:root{--md-text-font:"Roboto";--md-code-font:"Roboto Mono"}</style>
      
    
    
      <link rel="stylesheet" href="../assets/_mkdocstrings.css">
    
      <link rel="stylesheet" href="../stylesheets/extra.css">
    
    <script>
      // Bootstrap helpers shared by the theme's other scripts. They are assigned
      // as plain (undeclared) globals, exactly like the minified original.

      // URL of the parent directory of this page; its pathname prefixes every
      // storage key used below.
      __md_scope = new URL("..", location);

      // String hash: starting from 0, each step computes
      // (acc << 5) - acc + <char code of the current character>.
      __md_hash = (text) =>
        [...text].reduce((acc, ch) => (acc << 5) - acc + ch.charCodeAt(0), 0);

      // Read the entry stored under "<scope.pathname>.<key>" and JSON-parse it.
      // Storage defaults to localStorage and scope to __md_scope.
      __md_get = (key, storage = localStorage, scope = __md_scope) => {
        return JSON.parse(storage.getItem(scope.pathname + "." + key));
      };

      // JSON-serialise `value` and store it under "<scope.pathname>.<key>".
      // Any exception raised while serialising or writing is deliberately ignored.
      __md_set = (key, value, storage = localStorage, scope = __md_scope) => {
        try {
          storage.setItem(scope.pathname + "." + key, JSON.stringify(value));
        } catch (err) {}
      };
    </script>
    
      

    
    
    
  </head>
  
  
    
    
      
    
    
    
    
    <body dir="ltr" data-md-color-scheme="default" data-md-color-primary="red" data-md-color-accent="deep-orange">
  
    
    <input class="md-toggle" data-md-toggle="drawer" type="checkbox" id="__drawer" autocomplete="off">
    <input class="md-toggle" data-md-toggle="search" type="checkbox" id="__search" autocomplete="off">
    <label class="md-overlay" for="__drawer"></label>
    <div data-md-component="skip">
      
        
        <a href="#metrics" class="md-skip">
          Skip to content
        </a>
      
    </div>
    <div data-md-component="announce">
      
    </div>
    
    
      

  

<header class="md-header md-header--shadow md-header--lifted" data-md-component="header">
  <nav class="md-header__inner md-grid" aria-label="Header">
    <a href="../index.html" title="pytorch_widedeep" class="md-header__button md-logo" aria-label="pytorch_widedeep" data-md-component="logo">
      
  <img src="../assets/images/widedeep_logo.png" alt="logo">

    </a>
    <label class="md-header__button md-icon" for="__drawer">
      
      <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M3 6h18v2H3zm0 5h18v2H3zm0 5h18v2H3z"/></svg>
    </label>
    <div class="md-header__title" data-md-component="header-title">
      <div class="md-header__ellipsis">
        <div class="md-header__topic">
          <span class="md-ellipsis">
            pytorch_widedeep
          </span>
        </div>
        <div class="md-header__topic" data-md-component="header-topic">
          <span class="md-ellipsis">
            
              Metrics
            
          </span>
        </div>
      </div>
    </div>
    
      
        <form class="md-header__option" data-md-component="palette">
  
    
    
    
    <input class="md-option" data-md-color-media="" data-md-color-scheme="default" data-md-color-primary="red" data-md-color-accent="deep-orange"  aria-label="Switch to dark mode"  type="radio" name="__palette" id="__palette_0">
    
      <label class="md-header__button md-icon" title="Switch to dark mode" for="__palette_1" hidden>
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 8a4 4 0 0 0-4 4 4 4 0 0 0 4 4 4 4 0 0 0 4-4 4 4 0 0 0-4-4m0 10a6 6 0 0 1-6-6 6 6 0 0 1 6-6 6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
      </label>
    
  
    
    
    
    <input class="md-option" data-md-color-media="" data-md-color-scheme="slate" data-md-color-primary="red" data-md-color-accent="deep-orange"  aria-label="Switch to light mode"  type="radio" name="__palette" id="__palette_1">
    
      <label class="md-header__button md-icon" title="Switch to light mode" for="__palette_0" hidden>
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M12 18c-.89 0-1.74-.2-2.5-.55C11.56 16.5 13 14.42 13 12s-1.44-4.5-3.5-5.45C10.26 6.2 11.11 6 12 6a6 6 0 0 1 6 6 6 6 0 0 1-6 6m8-9.31V4h-4.69L12 .69 8.69 4H4v4.69L.69 12 4 15.31V20h4.69L12 23.31 15.31 20H20v-4.69L23.31 12z"/></svg>
      </label>
    
  
</form>
      
    
    
      <script>
        // Re-apply the colour palette saved under the "__palette" storage key to
        // <body> as soon as this script runs. The palette <input> elements queried
        // below are declared earlier in the header, so they already exist here.
        var palette = __md_get("__palette");

        if (palette && palette.color) {
          // A stored media value of "(prefers-color-scheme)" means "follow the
          // system": pick the light or dark palette input that matches the current
          // system preference and copy its colour attributes into the palette.
          if (palette.color.media === "(prefers-color-scheme)") {
            var media = matchMedia("(prefers-color-scheme: light)");
            var input = document.querySelector(
              media.matches
                ? "[data-md-color-media='(prefers-color-scheme: light)']"
                : "[data-md-color-media='(prefers-color-scheme: dark)']"
            );

            palette.color.media = input.getAttribute("data-md-color-media");
            palette.color.scheme = input.getAttribute("data-md-color-scheme");
            palette.color.primary = input.getAttribute("data-md-color-primary");
            palette.color.accent = input.getAttribute("data-md-color-accent");
          }

          // Mirror every palette entry onto <body> as a data-md-color-* attribute.
          for (var [key, value] of Object.entries(palette.color)) {
            document.body.setAttribute("data-md-color-" + key, value);
          }
        }
      </script>
    
    
    
      <label class="md-header__button md-icon" for="__search">
        
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
      </label>
      <div class="md-search" data-md-component="search" role="dialog">
  <label class="md-search__overlay" for="__search"></label>
  <div class="md-search__inner" role="search">
    <form class="md-search__form" name="search">
      <input type="text" class="md-search__input" name="query" aria-label="Search" placeholder="Search" autocapitalize="off" autocorrect="off" autocomplete="off" spellcheck="false" data-md-component="search-query" required>
      <label class="md-search__icon md-icon" for="__search">
        
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M9.5 3A6.5 6.5 0 0 1 16 9.5c0 1.61-.59 3.09-1.56 4.23l.27.27h.79l5 5-1.5 1.5-5-5v-.79l-.27-.27A6.52 6.52 0 0 1 9.5 16 6.5 6.5 0 0 1 3 9.5 6.5 6.5 0 0 1 9.5 3m0 2C7 5 5 7 5 9.5S7 14 9.5 14 14 12 14 9.5 12 5 9.5 5"/></svg>
        
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M20 11v2H8l5.5 5.5-1.42 1.42L4.16 12l7.92-7.92L13.5 5.5 8 11z"/></svg>
      </label>
      <nav class="md-search__options" aria-label="Search">
        
        <button type="reset" class="md-search__icon md-icon" title="Clear" aria-label="Clear" tabindex="-1">
          
          <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24"><path d="M19 6.41 17.59 5 12 10.59 6.41 5 5 6.41 10.59 12 5 17.59 6.41 19 12 13.41 17.59 19 19 17.59 13.41 12z"/></svg>
        </button>
      </nav>
      
    </form>
    <div class="md-search__output">
      <div class="md-search__scrollwrap" tabindex="0" data-md-scrollfix>
        <div class="md-search-result" data-md-component="search-result">
          <div class="md-search-result__meta">
            Initializing search
          </div>
          <ol class="md-search-result__list" role="presentation"></ol>
        </div>
      </div>
    </div>
  </div>
</div>
    
    
      <div class="md-header__source">
        <a href="https://github.com/jrzaurin/pytorch-widedeep" title="Go to repository" class="md-source" data-md-component="source">
  <div class="md-source__icon md-icon">
    
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M439.55 236.05 244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81"/></svg>
  </div>
  <div class="md-source__repository">
    pytorch_widedeep
  </div>
</a>
      </div>
    
  </nav>
  
    
      
<nav class="md-tabs" aria-label="Tabs" data-md-component="tabs">
  <div class="md-grid">
    <ul class="md-tabs__list">
      
        
  
  
  
    <li class="md-tabs__item">
      <a href="../index.html" class="md-tabs__link">
        
  
    
  
  Home

      </a>
    </li>
  

      
        
  
  
  
    <li class="md-tabs__item">
      <a href="../installation.html" class="md-tabs__link">
        
  
    
  
  Installation

      </a>
    </li>
  

      
        
  
  
  
    <li class="md-tabs__item">
      <a href="../quick_start.html" class="md-tabs__link">
        
  
    
  
  Quick Start

      </a>
    </li>
  

      
        
  
  
    
  
  
    
    
      
  
  
    
  
  
    
    
      <li class="md-tabs__item md-tabs__item--active">
        <a href="utils/index.html" class="md-tabs__link">
          
  
    
  
  Pytorch-widedeep

        </a>
      </li>
    
  

    
  

      
        
  
  
  
    
    
      <li class="md-tabs__item">
        <a href="../examples/01_preprocessors_and_utils.html" class="md-tabs__link">
          
  
    
  
  Examples

        </a>
      </li>
    
  

      
        
  
  
  
    <li class="md-tabs__item">
      <a href="../contributing.html" class="md-tabs__link">
        
  
    
  
  Contributing

      </a>
    </li>
  

      
    </ul>
  </div>
</nav>
    
  
</header>
    
    <div class="md-container" data-md-component="container">
      
      
        
      
      <main class="md-main" data-md-component="main">
        <div class="md-main__inner md-grid">
          
            
              
              <div class="md-sidebar md-sidebar--primary" data-md-component="sidebar" data-md-type="navigation" >
                <div class="md-sidebar__scrollwrap">
                  <div class="md-sidebar__inner">
                    


  


  

<nav class="md-nav md-nav--primary md-nav--lifted md-nav--integrated" aria-label="Navigation" data-md-level="0">
  <label class="md-nav__title" for="__drawer">
    <a href="../index.html" title="pytorch_widedeep" class="md-nav__button md-logo" aria-label="pytorch_widedeep" data-md-component="logo">
      
  <img src="../assets/images/widedeep_logo.png" alt="logo">

    </a>
    pytorch_widedeep
  </label>
  
    <div class="md-nav__source">
      <a href="https://github.com/jrzaurin/pytorch-widedeep" title="Go to repository" class="md-source" data-md-component="source">
  <div class="md-source__icon md-icon">
    
    <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 448 512"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M439.55 236.05 244 40.45a28.87 28.87 0 0 0-40.81 0l-40.66 40.63 51.52 51.52c27.06-9.14 52.68 16.77 43.39 43.68l49.66 49.66c34.23-11.8 61.18 31 35.47 56.69-26.49 26.49-70.21-2.87-56-37.34L240.22 199v121.85c25.3 12.54 22.26 41.85 9.08 55a34.34 34.34 0 0 1-48.55 0c-17.57-17.6-11.07-46.91 11.25-56v-123c-20.8-8.51-24.6-30.74-18.64-45L142.57 101 8.45 235.14a28.86 28.86 0 0 0 0 40.81l195.61 195.6a28.86 28.86 0 0 0 40.8 0l194.69-194.69a28.86 28.86 0 0 0 0-40.81"/></svg>
  </div>
  <div class="md-source__repository">
    pytorch_widedeep
  </div>
</a>
    </div>
  
  <ul class="md-nav__list" data-md-scrollfix>
    
      
      
  
  
  
  
    <li class="md-nav__item">
      <a href="../index.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Home
  </span>
  

      </a>
    </li>
  

    
      
      
  
  
  
  
    <li class="md-nav__item">
      <a href="../installation.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Installation
  </span>
  

      </a>
    </li>
  

    
      
      
  
  
  
  
    <li class="md-nav__item">
      <a href="../quick_start.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Quick Start
  </span>
  

      </a>
    </li>
  

    
      
      
  
  
    
  
  
  
    
    
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
    
    
      
        
        
      
      
    
    
    <li class="md-nav__item md-nav__item--active md-nav__item--section md-nav__item--nested">
      
        
        
        <input class="md-nav__toggle md-toggle " type="checkbox" id="__nav_4" checked>
        
          
          <label class="md-nav__link" for="__nav_4" id="__nav_4_label">
            
  
  <span class="md-ellipsis">
    Pytorch-widedeep
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_4_label" aria-expanded="true">
          <label class="md-nav__title" for="__nav_4">
            <span class="md-nav__icon md-icon"></span>
            Pytorch-widedeep
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    
    
      
        
          
        
      
        
      
        
      
        
      
        
      
    
    
      
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
          
        
        <input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_4_1" >
        
          
          
          <div class="md-nav__link md-nav__container">
            <a href="utils/index.html" class="md-nav__link ">
              
  
  <span class="md-ellipsis">
    Utils
  </span>
  

            </a>
            
              
              <label class="md-nav__link " for="__nav_4_1" id="__nav_4_1_label" tabindex="0">
                <span class="md-nav__icon md-icon"></span>
              </label>
            
          </div>
        
        <nav class="md-nav" data-md-level="2" aria-labelledby="__nav_4_1_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_4_1">
            <span class="md-nav__icon md-icon"></span>
            Utils
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="utils/deeptabular_utils.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Deeptabular utils
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="utils/fastai_transforms.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Fastai transforms
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="utils/image_utils.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Image utils
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="utils/text_utils.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Text utils
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="preprocessing.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Preprocessing
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="load_from_folder.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Load From Folder
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="model_components.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Model Components
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="the_rec_module.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    The Rec Module
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="bayesian_models.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Bayesian models
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="losses.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Losses
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
    
  
  
  
    <li class="md-nav__item md-nav__item--active">
      
      <input class="md-nav__toggle md-toggle" type="checkbox" id="__toc">
      
      
        
      
      
        <label class="md-nav__link md-nav__link--active" for="__toc">
          
  
  <span class="md-ellipsis">
    Metrics
  </span>
  

          <span class="md-nav__icon md-icon"></span>
        </label>
      
      <a href="metrics.html" class="md-nav__link md-nav__link--active">
        
  
  <span class="md-ellipsis">
    Metrics
  </span>
  

      </a>
      
        

<nav class="md-nav md-nav--secondary" aria-label="Table of contents">
  
  
  
    
  
  
    <label class="md-nav__title" for="__toc">
      <span class="md-nav__icon md-icon"></span>
      Table of contents
    </label>
    <ul class="md-nav__list" data-md-component="toc" data-md-scrollfix>
      
        <li class="md-nav__item">
  <a href="#pytorch_widedeep.metrics.Accuracy" class="md-nav__link">
    <span class="md-ellipsis">
      Accuracy
    </span>
  </a>
  
</li>
      
        <li class="md-nav__item">
  <a href="#pytorch_widedeep.metrics.Precision" class="md-nav__link">
    <span class="md-ellipsis">
      Precision
    </span>
  </a>
  
</li>
      
        <li class="md-nav__item">
  <a href="#pytorch_widedeep.metrics.Recall" class="md-nav__link">
    <span class="md-ellipsis">
      Recall
    </span>
  </a>
  
</li>
      
        <li class="md-nav__item">
  <a href="#pytorch_widedeep.metrics.FBetaScore" class="md-nav__link">
    <span class="md-ellipsis">
      FBetaScore
    </span>
  </a>
  
</li>
      
        <li class="md-nav__item">
  <a href="#pytorch_widedeep.metrics.F1Score" class="md-nav__link">
    <span class="md-ellipsis">
      F1Score
    </span>
  </a>
  
</li>
      
        <li class="md-nav__item">
  <a href="#pytorch_widedeep.metrics.R2Score" class="md-nav__link">
    <span class="md-ellipsis">
      R2Score
    </span>
  </a>
  
</li>
      
    </ul>
  
</nav>
      
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="dataloaders.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Dataloaders
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="callbacks.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Callbacks
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="trainer.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Trainer
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="bayesian_trainer.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Bayesian Trainer
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="self_supervised_pretraining.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Self Supervised Pretraining
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="tab2vec.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Tab2Vec
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

    
      
      
  
  
  
  
    
    
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
        
      
    
    
      
      
    
    
    <li class="md-nav__item md-nav__item--nested">
      
        
        
          
        
        <input class="md-nav__toggle md-toggle md-toggle--indeterminate" type="checkbox" id="__nav_5" >
        
          
          <label class="md-nav__link" for="__nav_5" id="__nav_5_label" tabindex="0">
            
  
  <span class="md-ellipsis">
    Examples
  </span>
  

            <span class="md-nav__icon md-icon"></span>
          </label>
        
        <nav class="md-nav" data-md-level="1" aria-labelledby="__nav_5_label" aria-expanded="false">
          <label class="md-nav__title" for="__nav_5">
            <span class="md-nav__icon md-icon"></span>
            Examples
          </label>
          <ul class="md-nav__list" data-md-scrollfix>
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/01_preprocessors_and_utils.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    01_preprocessors_and_utils
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/02_model_components.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    02_model_components
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/03_binary_classification_with_defaults.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    03_binary_classification_with_defaults
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/04_regression_with_images_and_text.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    04_regression_with_images_and_text
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/05_save_and_load_model_and_artifacts.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    05_save_and_load_model_and_artifacts
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/06_finetune_and_warmup.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    06_finetune_and_warmup
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/07_custom_components.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    07_custom_components
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/08_custom_dataLoader_imbalanced_dataset.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    08_custom_dataLoader_imbalanced_dataset
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/09_extracting_embeddings.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    09_extracting_embeddings
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/10_3rd_party_integration-RayTune_WnB.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    10_3rd_party_integration-RayTune_WnB
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/11_auc_multiclass.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    11_auc_multiclass
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/12_ZILNLoss_origkeras_vs_pytorch_widedeep.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    12_ZILNLoss_origkeras_vs_pytorch_widedeep
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/13_model_uncertainty_prediction.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    13_model_uncertainty_prediction
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/14_bayesian_models.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    14_bayesian_models
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/15_Self_Supervised_Pretraning_pt1.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    15_Self-Supervised Pre-Training pt 1
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/15_Self_Supervised_Pretraning_pt2.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    15_Self-Supervised Pre-Training pt 2
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/16_Usign_a_custom_hugging_face_model.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    16_Usign-a-custom-hugging-face-model
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/17_feature_importance_via_attention_weights.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    17_feature_importance_via_attention_weights
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/18_wide_and_deep_for_recsys_pt1.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    18_wide_and_deep_for_recsys_pt1
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/18_wide_and_deep_for_recsys_pt2.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    18_wide_and_deep_for_recsys_pt2
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/19_load_from_folder_functionality.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    19_load_from_folder_functionality
  </span>
  

      </a>
    </li>
  

              
            
              
                
  
  
  
  
    <li class="md-nav__item">
      <a href="../examples/20_Using_huggingface_within_widedeep.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    20-Using-huggingface-within-widedeep
  </span>
  

      </a>
    </li>
  

              
            
          </ul>
        </nav>
      
    </li>
  

    
      
      
  
  
  
  
    <li class="md-nav__item">
      <a href="../contributing.html" class="md-nav__link">
        
  
  <span class="md-ellipsis">
    Contributing
  </span>
  

      </a>
    </li>
  

    
  </ul>
</nav>
                  </div>
                </div>
              </div>
            
            
          
          
            <div class="md-content" data-md-component="content">
              <article class="md-content__inner md-typeset">
                
                  

  
  


<h1 id="metrics">Metrics<a class="headerlink" href="#metrics" title="Permanent link">&para;</a></h1>
<hr />
<p><img alt="ℹ️" class="emojione" src="https://cdnjs.cloudflare.com/ajax/libs/emojione/2.2.7/assets/png/2139.png" title=":information_source:" /> <strong>NOTE</strong>: metrics in this module expect the predictions
 and ground truth to have the same dimensions for regression and binary
 classification problems: <span class="arithmatex">\((N_{samples}, 1)\)</span>. In the case of multiclass
 classification problems the ground truth is expected to be a 1D tensor with
 the corresponding classes. See the Examples below.</p>
<hr />
<p>We have added the possibility of using the metrics available at the
<a href="https://torchmetrics.readthedocs.io/en/latest/">torchmetrics</a> library. Note
that this library is still in its early versions and therefore this option
should be used with caution. To use the <code>torchmetrics</code> metrics, simply import them and
use them like any of the <code>pytorch-widedeep</code> metrics described below.</p>
<div class="highlight"><pre><span></span><code><span class="kn">from</span> <span class="nn">torchmetrics</span> <span class="kn">import</span> <span class="n">Accuracy</span><span class="p">,</span> <span class="n">Precision</span>

<span class="n">accuracy</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">(</span><span class="n">average</span><span class="o">=</span><span class="kc">None</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="n">precision</span> <span class="o">=</span> <span class="n">Precision</span><span class="p">(</span><span class="n">average</span><span class="o">=</span><span class="s1">&#39;micro&#39;</span><span class="p">,</span> <span class="n">num_classes</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>

<span class="n">trainer</span> <span class="o">=</span> <span class="n">Trainer</span><span class="p">(</span><span class="n">model</span><span class="p">,</span> <span class="n">objective</span><span class="o">=</span><span class="s2">&quot;binary&quot;</span><span class="p">,</span> <span class="n">metrics</span><span class="o">=</span><span class="p">[</span><span class="n">accuracy</span><span class="p">,</span> <span class="n">precision</span><span class="p">])</span>
</code></pre></div>
<p>A functioning example for <code>pytorch-widedeep</code> using <code>torchmetrics</code> can be
found in the <a href="https://github.com/jrzaurin/pytorch-widedeep/blob/master/examples">Examples folder</a>.</p>
<p><img alt="ℹ️" class="emojione" src="https://cdnjs.cloudflare.com/ajax/libs/emojione/2.2.7/assets/png/2139.png" title=":information_source:" /> <strong>NOTE</strong>: the forward method for all metrics in this
 module takes two tensors, <code>y_pred</code> and <code>y_true</code> (in that order). Therefore,
 we do not include the method in the documentation.</p>


<div class="doc doc-object doc-class">



<h2 id="pytorch_widedeep.metrics.Accuracy" class="doc doc-heading">
            <span class="doc doc-object-name doc-class-name">Accuracy</span>


<a href="#pytorch_widedeep.metrics.Accuracy" class="headerlink" title="Permanent link">&para;</a></h2>


    <div class="doc doc-contents first">
            <p class="doc doc-class-bases">
              Bases: <code><span title="pytorch_widedeep.metrics.Metric">Metric</span></code></p>


        <p>Class to calculate the accuracy for both binary and categorical problems</p>


<p><span class="doc-section-title">Parameters:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
          <th>Default</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td>
                <code>top_k</code>
            </td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Accuracy will be computed using the top k most likely classes in
multiclass problems</p>
              </div>
            </td>
            <td>
                  <code>1</code>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Examples:</span></p>
    <div class="highlight"><pre><span></span><code><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">torch</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.metrics</span> <span class="kn">import</span> <span class="n">Accuracy</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">acc</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">]])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">acc</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.5)</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">acc</span> <span class="o">=</span> <span class="n">Accuracy</span><span class="p">(</span><span class="n">top_k</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">]])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">acc</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.66666667)</span>
</code></pre></div>






              <details class="quote">
                <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
                <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal"> 48</span>
<span class="normal"> 49</span>
<span class="normal"> 50</span>
<span class="normal"> 51</span>
<span class="normal"> 52</span>
<span class="normal"> 53</span>
<span class="normal"> 54</span>
<span class="normal"> 55</span>
<span class="normal"> 56</span>
<span class="normal"> 57</span>
<span class="normal"> 58</span>
<span class="normal"> 59</span>
<span class="normal"> 60</span>
<span class="normal"> 61</span>
<span class="normal"> 62</span>
<span class="normal"> 63</span>
<span class="normal"> 64</span>
<span class="normal"> 65</span>
<span class="normal"> 66</span>
<span class="normal"> 67</span>
<span class="normal"> 68</span>
<span class="normal"> 69</span>
<span class="normal"> 70</span>
<span class="normal"> 71</span>
<span class="normal"> 72</span>
<span class="normal"> 73</span>
<span class="normal"> 74</span>
<span class="normal"> 75</span>
<span class="normal"> 76</span>
<span class="normal"> 77</span>
<span class="normal"> 78</span>
<span class="normal"> 79</span>
<span class="normal"> 80</span>
<span class="normal"> 81</span>
<span class="normal"> 82</span>
<span class="normal"> 83</span>
<span class="normal"> 84</span>
<span class="normal"> 85</span>
<span class="normal"> 86</span>
<span class="normal"> 87</span>
<span class="normal"> 88</span>
<span class="normal"> 89</span>
<span class="normal"> 90</span>
<span class="normal"> 91</span>
<span class="normal"> 92</span>
<span class="normal"> 93</span>
<span class="normal"> 94</span>
<span class="normal"> 95</span>
<span class="normal"> 96</span>
<span class="normal"> 97</span>
<span class="normal"> 98</span>
<span class="normal"> 99</span>
<span class="normal">100</span>
<span class="normal">101</span>
<span class="normal">102</span>
<span class="normal">103</span>
<span class="normal">104</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">Accuracy</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Class to calculate the accuracy for both binary and categorical problems</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    top_k: int, default = 1</span>
<span class="sd">        Accuracy will be computed using the top k most likely classes in</span>
<span class="sd">        multiclass problems</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">    &gt;&gt;&gt; import torch</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.metrics import Accuracy</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; acc = Accuracy()</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; acc(y_pred, y_true)</span>
<span class="sd">    array(0.5)</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; acc = Accuracy(top_k=2)</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 2])</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.3, 0.5, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])</span>
<span class="sd">    &gt;&gt;&gt; acc(y_pred, y_true)</span>
<span class="sd">    array(0.66666667)</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">top_k</span><span class="p">:</span> <span class="nb">int</span> <span class="o">=</span> <span class="mi">1</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Accuracy</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">top_k</span> <span class="o">=</span> <span class="n">top_k</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">correct_count</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">total_count</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">&quot;acc&quot;</span>

    <span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        resets counters to 0</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">correct_count</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">total_count</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y_true</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
        <span class="n">num_classes</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">num_classes</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">y_pred</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">round</span><span class="p">()</span>
            <span class="n">y_true</span> <span class="o">=</span> <span class="n">y_true</span>
        <span class="k">elif</span> <span class="n">num_classes</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">y_pred</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">top_k</span><span class="p">,</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span>
            <span class="n">y_true</span> <span class="o">=</span> <span class="n">y_true</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span><span class="o">.</span><span class="n">expand_as</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">correct_count</span> <span class="o">+=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">eq</span><span class="p">(</span><span class="n">y_true</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>  <span class="c1"># type: ignore[assignment]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">total_count</span> <span class="o">+=</span> <span class="nb">len</span><span class="p">(</span><span class="n">y_pred</span><span class="p">)</span>
        <span class="n">accuracy</span> <span class="o">=</span> <span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">correct_count</span><span class="p">)</span> <span class="o">/</span> <span class="nb">float</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">total_count</span><span class="p">)</span>
        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">accuracy</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
              </details>



  <div class="doc doc-children">









<div class="doc doc-object doc-function">


<h3 id="pytorch_widedeep.metrics.Accuracy.reset" class="doc doc-heading">
            <span class="doc doc-object-name doc-function-name">reset</span>


<a href="#pytorch_widedeep.metrics.Accuracy.reset" class="headerlink" title="Permanent link">&para;</a></h3>
<div class="doc-signature highlight"><pre><span></span><code><span class="nf">reset</span><span class="p">()</span>
</code></pre></div>

    <div class="doc doc-contents ">

        <p>resets counters to 0</p>

            <details class="quote">
              <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
              <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">84</span>
<span class="normal">85</span>
<span class="normal">86</span>
<span class="normal">87</span>
<span class="normal">88</span>
<span class="normal">89</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    resets counters to 0</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">correct_count</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">total_count</span> <span class="o">=</span> <span class="mi">0</span>
</code></pre></div></td></tr></table></div>
            </details>
    </div>

</div>



  </div>

    </div>

</div>

<div class="doc doc-object doc-class">



<h2 id="pytorch_widedeep.metrics.Precision" class="doc doc-heading">
            <span class="doc doc-object-name doc-class-name">Precision</span>


<a href="#pytorch_widedeep.metrics.Precision" class="headerlink" title="Permanent link">&para;</a></h2>


    <div class="doc doc-contents first">
            <p class="doc doc-class-bases">
              Bases: <code><span title="pytorch_widedeep.metrics.Metric">Metric</span></code></p>


        <p>Class to calculate the precision for both binary and categorical problems</p>


<p><span class="doc-section-title">Parameters:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
          <th>Default</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td>
                <code>average</code>
            </td>
            <td>
                  <code>bool</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>This applies only to multiclass problems. If <code>True</code>, calculates the
precision for each label and finds their unweighted mean.</p>
              </div>
            </td>
            <td>
                  <code>True</code>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Examples:</span></p>
    <div class="highlight"><pre><span></span><code><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">torch</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.metrics</span> <span class="kn">import</span> <span class="n">Precision</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">prec</span> <span class="o">=</span> <span class="n">Precision</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">]])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">prec</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.5)</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">prec</span> <span class="o">=</span> <span class="n">Precision</span><span class="p">(</span><span class="n">average</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">]])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">prec</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.33333334)</span>
</code></pre></div>






              <details class="quote">
                <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
                <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">107</span>
<span class="normal">108</span>
<span class="normal">109</span>
<span class="normal">110</span>
<span class="normal">111</span>
<span class="normal">112</span>
<span class="normal">113</span>
<span class="normal">114</span>
<span class="normal">115</span>
<span class="normal">116</span>
<span class="normal">117</span>
<span class="normal">118</span>
<span class="normal">119</span>
<span class="normal">120</span>
<span class="normal">121</span>
<span class="normal">122</span>
<span class="normal">123</span>
<span class="normal">124</span>
<span class="normal">125</span>
<span class="normal">126</span>
<span class="normal">127</span>
<span class="normal">128</span>
<span class="normal">129</span>
<span class="normal">130</span>
<span class="normal">131</span>
<span class="normal">132</span>
<span class="normal">133</span>
<span class="normal">134</span>
<span class="normal">135</span>
<span class="normal">136</span>
<span class="normal">137</span>
<span class="normal">138</span>
<span class="normal">139</span>
<span class="normal">140</span>
<span class="normal">141</span>
<span class="normal">142</span>
<span class="normal">143</span>
<span class="normal">144</span>
<span class="normal">145</span>
<span class="normal">146</span>
<span class="normal">147</span>
<span class="normal">148</span>
<span class="normal">149</span>
<span class="normal">150</span>
<span class="normal">151</span>
<span class="normal">152</span>
<span class="normal">153</span>
<span class="normal">154</span>
<span class="normal">155</span>
<span class="normal">156</span>
<span class="normal">157</span>
<span class="normal">158</span>
<span class="normal">159</span>
<span class="normal">160</span>
<span class="normal">161</span>
<span class="normal">162</span>
<span class="normal">163</span>
<span class="normal">164</span>
<span class="normal">165</span>
<span class="normal">166</span>
<span class="normal">167</span>
<span class="normal">168</span>
<span class="normal">169</span>
<span class="normal">170</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">Precision</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Class to calculate the precision for both binary and categorical problems</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    average: bool, default = True</span>
<span class="sd">        This applies only to multiclass problems. if ``True`` calculate</span>
<span class="sd">        precision for each label, and finds their unweighted mean.</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">    &gt;&gt;&gt; import torch</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.metrics import Precision</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; prec = Precision()</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; prec(y_pred, y_true)</span>
<span class="sd">    array(0.5)</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; prec = Precision(average=True)</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 2])</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])</span>
<span class="sd">    &gt;&gt;&gt; prec(y_pred, y_true)</span>
<span class="sd">    array(0.33333334)</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">average</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Precision</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">average</span> <span class="o">=</span> <span class="n">average</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">all_positives</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="mf">1e-20</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">&quot;prec&quot;</span>

    <span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        resets counters to 0</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">all_positives</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y_true</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
        <span class="n">num_class</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">num_class</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">y_pred</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">round</span><span class="p">()</span>
            <span class="n">y_true</span> <span class="o">=</span> <span class="n">y_true</span>
        <span class="k">elif</span> <span class="n">num_class</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_class</span><span class="p">)[</span><span class="n">y_true</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">long</span><span class="p">()]</span>
            <span class="n">y_pred</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_class</span><span class="p">)[</span><span class="n">y_pred</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">long</span><span class="p">()]</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">+=</span> <span class="p">(</span><span class="n">y_true</span> <span class="o">*</span> <span class="n">y_pred</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>  <span class="c1"># type:ignore</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">all_positives</span> <span class="o">+=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>  <span class="c1"># type:ignore</span>

        <span class="n">precision</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">all_positives</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">average</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">precision</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>  <span class="c1"># type:ignore</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">precision</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>  <span class="c1"># type: ignore[attr-defined]</span>
</code></pre></div></td></tr></table></div>
              </details>



  <div class="doc doc-children">









<div class="doc doc-object doc-function">


<h3 id="pytorch_widedeep.metrics.Precision.reset" class="doc doc-heading">
            <span class="doc doc-object-name doc-function-name">reset</span>


<a href="#pytorch_widedeep.metrics.Precision.reset" class="headerlink" title="Permanent link">&para;</a></h3>
<div class="doc-signature highlight"><pre><span></span><code><span class="nf">reset</span><span class="p">()</span>
</code></pre></div>

    <div class="doc doc-contents ">

        <p>Resets the accumulated <code>true_positives</code> and <code>all_positives</code> counters to 0.</p>

            <details class="quote">
              <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
              <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">144</span>
<span class="normal">145</span>
<span class="normal">146</span>
<span class="normal">147</span>
<span class="normal">148</span>
<span class="normal">149</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    resets counters to 0</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">all_positives</span> <span class="o">=</span> <span class="mi">0</span>
</code></pre></div></td></tr></table></div>
            </details>
    </div>

</div>



  </div>

    </div>

</div>

<div class="doc doc-object doc-class">



<h2 id="pytorch_widedeep.metrics.Recall" class="doc doc-heading">
            <span class="doc doc-object-name doc-class-name">Recall</span>


<a href="#pytorch_widedeep.metrics.Recall" class="headerlink" title="Permanent link">&para;</a></h2>


    <div class="doc doc-contents first">
            <p class="doc doc-class-bases">
              Bases: <code><span title="pytorch_widedeep.metrics.Metric">Metric</span></code></p>


        <p>Class to calculate the recall for both binary and categorical problems</p>


<p><span class="doc-section-title">Parameters:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
          <th>Default</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td>
                <code>average</code>
            </td>
            <td>
                  <code>bool</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>This applies only to multiclass problems. If <code>True</code>, calculates recall
for each label and finds their unweighted mean.</p>
              </div>
            </td>
            <td>
                  <code>True</code>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Examples:</span></p>
    <div class="highlight"><pre><span></span><code><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">torch</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.metrics</span> <span class="kn">import</span> <span class="n">Recall</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">rec</span> <span class="o">=</span> <span class="n">Recall</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">]])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">rec</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.5)</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">rec</span> <span class="o">=</span> <span class="n">Recall</span><span class="p">(</span><span class="n">average</span><span class="o">=</span><span class="kc">True</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">]])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">rec</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.33333334)</span>
</code></pre></div>






              <details class="quote">
                <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
                <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">173</span>
<span class="normal">174</span>
<span class="normal">175</span>
<span class="normal">176</span>
<span class="normal">177</span>
<span class="normal">178</span>
<span class="normal">179</span>
<span class="normal">180</span>
<span class="normal">181</span>
<span class="normal">182</span>
<span class="normal">183</span>
<span class="normal">184</span>
<span class="normal">185</span>
<span class="normal">186</span>
<span class="normal">187</span>
<span class="normal">188</span>
<span class="normal">189</span>
<span class="normal">190</span>
<span class="normal">191</span>
<span class="normal">192</span>
<span class="normal">193</span>
<span class="normal">194</span>
<span class="normal">195</span>
<span class="normal">196</span>
<span class="normal">197</span>
<span class="normal">198</span>
<span class="normal">199</span>
<span class="normal">200</span>
<span class="normal">201</span>
<span class="normal">202</span>
<span class="normal">203</span>
<span class="normal">204</span>
<span class="normal">205</span>
<span class="normal">206</span>
<span class="normal">207</span>
<span class="normal">208</span>
<span class="normal">209</span>
<span class="normal">210</span>
<span class="normal">211</span>
<span class="normal">212</span>
<span class="normal">213</span>
<span class="normal">214</span>
<span class="normal">215</span>
<span class="normal">216</span>
<span class="normal">217</span>
<span class="normal">218</span>
<span class="normal">219</span>
<span class="normal">220</span>
<span class="normal">221</span>
<span class="normal">222</span>
<span class="normal">223</span>
<span class="normal">224</span>
<span class="normal">225</span>
<span class="normal">226</span>
<span class="normal">227</span>
<span class="normal">228</span>
<span class="normal">229</span>
<span class="normal">230</span>
<span class="normal">231</span>
<span class="normal">232</span>
<span class="normal">233</span>
<span class="normal">234</span>
<span class="normal">235</span>
<span class="normal">236</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">Recall</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Class to calculate the recall for both binary and categorical problems</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    average: bool, default = True</span>
<span class="sd">        This applies only to multiclass problems. if ``True`` calculate recall</span>
<span class="sd">        for each label, and finds their unweighted mean.</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">    &gt;&gt;&gt; import torch</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.metrics import Recall</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; rec = Recall()</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; rec(y_pred, y_true)</span>
<span class="sd">    array(0.5)</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; rec = Recall(average=True)</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 2])</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])</span>
<span class="sd">    &gt;&gt;&gt; rec(y_pred, y_true)</span>
<span class="sd">    array(0.33333334)</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">average</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">Recall</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">average</span> <span class="o">=</span> <span class="n">average</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">actual_positives</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="mf">1e-20</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">&quot;rec&quot;</span>

    <span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        resets counters to 0</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">actual_positives</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y_true</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
        <span class="n">num_class</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">size</span><span class="p">(</span><span class="mi">1</span><span class="p">)</span>

        <span class="k">if</span> <span class="n">num_class</span> <span class="o">==</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">y_pred</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">round</span><span class="p">()</span>
            <span class="n">y_true</span> <span class="o">=</span> <span class="n">y_true</span>
        <span class="k">elif</span> <span class="n">num_class</span> <span class="o">&gt;</span> <span class="mi">1</span><span class="p">:</span>
            <span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_class</span><span class="p">)[</span><span class="n">y_true</span><span class="o">.</span><span class="n">squeeze</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">long</span><span class="p">()]</span>
            <span class="n">y_pred</span> <span class="o">=</span> <span class="n">y_pred</span><span class="o">.</span><span class="n">topk</span><span class="p">(</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)[</span><span class="mi">1</span><span class="p">]</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">)</span>
            <span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">eye</span><span class="p">(</span><span class="n">num_class</span><span class="p">)[</span><span class="n">y_pred</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">long</span><span class="p">()]</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">+=</span> <span class="p">(</span><span class="n">y_true</span> <span class="o">*</span> <span class="n">y_pred</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>  <span class="c1"># type: ignore</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">actual_positives</span> <span class="o">+=</span> <span class="n">y_true</span><span class="o">.</span><span class="n">sum</span><span class="p">(</span><span class="n">dim</span><span class="o">=</span><span class="mi">0</span><span class="p">)</span>  <span class="c1"># type: ignore</span>

        <span class="n">recall</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">/</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">actual_positives</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">average</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">recall</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>  <span class="c1"># type:ignore</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">recall</span><span class="o">.</span><span class="n">detach</span><span class="p">()</span><span class="o">.</span><span class="n">cpu</span><span class="p">()</span><span class="o">.</span><span class="n">numpy</span><span class="p">()</span>  <span class="c1"># type: ignore[attr-defined]</span>
</code></pre></div></td></tr></table></div>
              </details>



  <div class="doc doc-children">









<div class="doc doc-object doc-function">


<h3 id="pytorch_widedeep.metrics.Recall.reset" class="doc doc-heading">
            <span class="doc doc-object-name doc-function-name">reset</span>


<a href="#pytorch_widedeep.metrics.Recall.reset" class="headerlink" title="Permanent link">&para;</a></h3>
<div class="doc-signature highlight"><pre><span></span><code><span class="nf">reset</span><span class="p">()</span>
</code></pre></div>

    <div class="doc doc-contents ">

        <p>Resets the accumulated <code>true_positives</code> and <code>actual_positives</code> counters to 0.</p>

            <details class="quote">
              <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
              <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">210</span>
<span class="normal">211</span>
<span class="normal">212</span>
<span class="normal">213</span>
<span class="normal">214</span>
<span class="normal">215</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    resets counters to 0</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">true_positives</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">actual_positives</span> <span class="o">=</span> <span class="mi">0</span>
</code></pre></div></td></tr></table></div>
            </details>
    </div>

</div>



  </div>

    </div>

</div>

<div class="doc doc-object doc-class">



<h2 id="pytorch_widedeep.metrics.FBetaScore" class="doc doc-heading">
            <span class="doc doc-object-name doc-class-name">FBetaScore</span>


<a href="#pytorch_widedeep.metrics.FBetaScore" class="headerlink" title="Permanent link">&para;</a></h2>


    <div class="doc doc-contents first">
            <p class="doc doc-class-bases">
              Bases: <code><span title="pytorch_widedeep.metrics.Metric">Metric</span></code></p>


        <p>Class to calculate the fbeta score for both binary and categorical problems</p>
<div class="arithmatex">\[
F_{\beta} = (1 + \beta^2) \cdot \frac{\text{precision} \cdot \text{recall}}{\beta^2 \cdot \text{precision} + \text{recall}}
\]</div>


<p><span class="doc-section-title">Parameters:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
          <th>Default</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td>
                <code>beta</code>
            </td>
            <td>
                  <code>int</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>Coefficient to control the balance between precision and recall</p>
              </div>
            </td>
            <td>
                <em>required</em>
            </td>
          </tr>
          <tr class="doc-section-item">
            <td>
                <code>average</code>
            </td>
            <td>
                  <code>bool</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>This applies only to multiclass problems. If <code>True</code>, calculates fbeta
for each label and finds their unweighted mean.</p>
              </div>
            </td>
            <td>
                  <code>True</code>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Examples:</span></p>
    <div class="highlight"><pre><span></span><code><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">torch</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.metrics</span> <span class="kn">import</span> <span class="n">FBetaScore</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">fbeta</span> <span class="o">=</span> <span class="n">FBetaScore</span><span class="p">(</span><span class="n">beta</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">]])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">fbeta</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.5)</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">fbeta</span> <span class="o">=</span> <span class="n">FBetaScore</span><span class="p">(</span><span class="n">beta</span><span class="o">=</span><span class="mi">2</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">]])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">fbeta</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.33333334)</span>
</code></pre></div>






              <details class="quote">
                <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
                <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">239</span>
<span class="normal">240</span>
<span class="normal">241</span>
<span class="normal">242</span>
<span class="normal">243</span>
<span class="normal">244</span>
<span class="normal">245</span>
<span class="normal">246</span>
<span class="normal">247</span>
<span class="normal">248</span>
<span class="normal">249</span>
<span class="normal">250</span>
<span class="normal">251</span>
<span class="normal">252</span>
<span class="normal">253</span>
<span class="normal">254</span>
<span class="normal">255</span>
<span class="normal">256</span>
<span class="normal">257</span>
<span class="normal">258</span>
<span class="normal">259</span>
<span class="normal">260</span>
<span class="normal">261</span>
<span class="normal">262</span>
<span class="normal">263</span>
<span class="normal">264</span>
<span class="normal">265</span>
<span class="normal">266</span>
<span class="normal">267</span>
<span class="normal">268</span>
<span class="normal">269</span>
<span class="normal">270</span>
<span class="normal">271</span>
<span class="normal">272</span>
<span class="normal">273</span>
<span class="normal">274</span>
<span class="normal">275</span>
<span class="normal">276</span>
<span class="normal">277</span>
<span class="normal">278</span>
<span class="normal">279</span>
<span class="normal">280</span>
<span class="normal">281</span>
<span class="normal">282</span>
<span class="normal">283</span>
<span class="normal">284</span>
<span class="normal">285</span>
<span class="normal">286</span>
<span class="normal">287</span>
<span class="normal">288</span>
<span class="normal">289</span>
<span class="normal">290</span>
<span class="normal">291</span>
<span class="normal">292</span>
<span class="normal">293</span>
<span class="normal">294</span>
<span class="normal">295</span>
<span class="normal">296</span>
<span class="normal">297</span>
<span class="normal">298</span>
<span class="normal">299</span>
<span class="normal">300</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">FBetaScore</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Class to calculate the fbeta score for both binary and categorical problems</span>

<span class="sd">    $$</span>
<span class="sd">    F_{\beta} = ((1 + {\beta}^2) * \frac{(precision * recall)}{({\beta}^2 * precision + recall)}</span>
<span class="sd">    $$</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    beta: int</span>
<span class="sd">        Coefficient to control the balance between precision and recall</span>
<span class="sd">    average: bool, default = True</span>
<span class="sd">        This applies only to multiclass problems. if ``True`` calculate fbeta</span>
<span class="sd">        for each label, and find their unweighted mean.</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">    &gt;&gt;&gt; import torch</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.metrics import FBetaScore</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; fbeta = FBetaScore(beta=2)</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; fbeta(y_pred, y_true)</span>
<span class="sd">    array(0.5)</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; fbeta = FBetaScore(beta=2)</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 2])</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])</span>
<span class="sd">    &gt;&gt;&gt; fbeta(y_pred, y_true)</span>
<span class="sd">    array(0.33333334)</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">beta</span><span class="p">:</span> <span class="nb">int</span><span class="p">,</span> <span class="n">average</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">FBetaScore</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">beta</span> <span class="o">=</span> <span class="n">beta</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">average</span> <span class="o">=</span> <span class="n">average</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">precision</span> <span class="o">=</span> <span class="n">Precision</span><span class="p">(</span><span class="n">average</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">recall</span> <span class="o">=</span> <span class="n">Recall</span><span class="p">(</span><span class="n">average</span><span class="o">=</span><span class="kc">False</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">eps</span> <span class="o">=</span> <span class="mf">1e-20</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">&quot;&quot;</span><span class="o">.</span><span class="n">join</span><span class="p">([</span><span class="s2">&quot;f&quot;</span><span class="p">,</span> <span class="nb">str</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">beta</span><span class="p">)])</span>

    <span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        resets precision and recall</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">precision</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">recall</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y_true</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
        <span class="n">prec</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">precision</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
        <span class="n">rec</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">recall</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
        <span class="n">beta2</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">beta</span><span class="o">**</span><span class="mi">2</span>

        <span class="n">fbeta</span> <span class="o">=</span> <span class="p">((</span><span class="mi">1</span> <span class="o">+</span> <span class="n">beta2</span><span class="p">)</span> <span class="o">*</span> <span class="n">prec</span> <span class="o">*</span> <span class="n">rec</span><span class="p">)</span> <span class="o">/</span> <span class="p">(</span><span class="n">beta2</span> <span class="o">*</span> <span class="n">prec</span> <span class="o">+</span> <span class="n">rec</span> <span class="o">+</span> <span class="bp">self</span><span class="o">.</span><span class="n">eps</span><span class="p">)</span>

        <span class="k">if</span> <span class="bp">self</span><span class="o">.</span><span class="n">average</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">(</span><span class="n">fbeta</span><span class="o">.</span><span class="n">mean</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">())</span>  <span class="c1"># type: ignore[attr-defined]</span>
        <span class="k">else</span><span class="p">:</span>
            <span class="k">return</span> <span class="n">fbeta</span>
</code></pre></div></td></tr></table></div>
              </details>



  <div class="doc doc-children">









<div class="doc doc-object doc-function">


<h3 id="pytorch_widedeep.metrics.FBetaScore.reset" class="doc doc-heading">
            <span class="doc doc-object-name doc-function-name">reset</span>


<a href="#pytorch_widedeep.metrics.FBetaScore.reset" class="headerlink" title="Permanent link">&para;</a></h3>
<div class="doc-signature highlight"><pre><span></span><code><span class="nf">reset</span><span class="p">()</span>
</code></pre></div>

    <div class="doc doc-contents ">

        <p>resets precision and recall</p>

            <details class="quote">
              <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
              <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">283</span>
<span class="normal">284</span>
<span class="normal">285</span>
<span class="normal">286</span>
<span class="normal">287</span>
<span class="normal">288</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    resets precision and recall</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">precision</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">recall</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
</code></pre></div></td></tr></table></div>
            </details>
    </div>

</div>



  </div>

    </div>

</div>

<div class="doc doc-object doc-class">



<h2 id="pytorch_widedeep.metrics.F1Score" class="doc doc-heading">
            <span class="doc doc-object-name doc-class-name">F1Score</span>


<a href="#pytorch_widedeep.metrics.F1Score" class="headerlink" title="Permanent link">&para;</a></h2>


    <div class="doc doc-contents first">
            <p class="doc doc-class-bases">
              Bases: <code><span title="pytorch_widedeep.metrics.Metric">Metric</span></code></p>


        <p>Class to calculate the f1 score for both binary and categorical problems</p>


<p><span class="doc-section-title">Parameters:</span></p>
    <table>
      <thead>
        <tr>
          <th>Name</th>
          <th>Type</th>
          <th>Description</th>
          <th>Default</th>
        </tr>
      </thead>
      <tbody>
          <tr class="doc-section-item">
            <td>
                <code>average</code>
            </td>
            <td>
                  <code>bool</code>
            </td>
            <td>
              <div class="doc-md-description">
                <p>This applies only to multiclass problems. If <code>True</code>, calculate f1 for
each label and find their unweighted mean.</p>
              </div>
            </td>
            <td>
                  <code>True</code>
            </td>
          </tr>
      </tbody>
    </table>


<p><span class="doc-section-title">Examples:</span></p>
    <div class="highlight"><pre><span></span><code><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">torch</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.metrics</span> <span class="kn">import</span> <span class="n">F1Score</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">f1</span> <span class="o">=</span> <span class="n">F1Score</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.3</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">,</span> <span class="mf">0.6</span><span class="p">,</span> <span class="mf">0.7</span><span class="p">]])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">f1</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.5)</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">f1</span> <span class="o">=</span> <span class="n">F1Score</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">0</span><span class="p">,</span> <span class="mi">1</span><span class="p">,</span> <span class="mi">2</span><span class="p">])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([[</span><span class="mf">0.7</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.2</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.8</span><span class="p">],</span> <span class="p">[</span><span class="mf">0.1</span><span class="p">,</span> <span class="mf">0.5</span><span class="p">,</span> <span class="mf">0.4</span><span class="p">]])</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">f1</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.33333334)</span>
</code></pre></div>






              <details class="quote">
                <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
                <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">303</span>
<span class="normal">304</span>
<span class="normal">305</span>
<span class="normal">306</span>
<span class="normal">307</span>
<span class="normal">308</span>
<span class="normal">309</span>
<span class="normal">310</span>
<span class="normal">311</span>
<span class="normal">312</span>
<span class="normal">313</span>
<span class="normal">314</span>
<span class="normal">315</span>
<span class="normal">316</span>
<span class="normal">317</span>
<span class="normal">318</span>
<span class="normal">319</span>
<span class="normal">320</span>
<span class="normal">321</span>
<span class="normal">322</span>
<span class="normal">323</span>
<span class="normal">324</span>
<span class="normal">325</span>
<span class="normal">326</span>
<span class="normal">327</span>
<span class="normal">328</span>
<span class="normal">329</span>
<span class="normal">330</span>
<span class="normal">331</span>
<span class="normal">332</span>
<span class="normal">333</span>
<span class="normal">334</span>
<span class="normal">335</span>
<span class="normal">336</span>
<span class="normal">337</span>
<span class="normal">338</span>
<span class="normal">339</span>
<span class="normal">340</span>
<span class="normal">341</span>
<span class="normal">342</span>
<span class="normal">343</span>
<span class="normal">344</span>
<span class="normal">345</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">F1Score</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;Class to calculate the f1 score for both binary and categorical problems</span>

<span class="sd">    Parameters</span>
<span class="sd">    ----------</span>
<span class="sd">    average: bool, default = True</span>
<span class="sd">        This applies only to multiclass problems. if ``True`` calculate f1 for</span>
<span class="sd">        each label, and find their unweighted mean.</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">    &gt;&gt;&gt; import torch</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.metrics import F1Score</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; f1 = F1Score()</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 0, 1]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.3, 0.2, 0.6, 0.7]]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; f1(y_pred, y_true)</span>
<span class="sd">    array(0.5)</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; f1 = F1Score()</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([0, 1, 2])</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([[0.7, 0.1, 0.2], [0.1, 0.1, 0.8], [0.1, 0.5, 0.4]])</span>
<span class="sd">    &gt;&gt;&gt; f1(y_pred, y_true)</span>
<span class="sd">    array(0.33333334)</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">average</span><span class="p">:</span> <span class="nb">bool</span> <span class="o">=</span> <span class="kc">True</span><span class="p">):</span>
        <span class="nb">super</span><span class="p">(</span><span class="n">F1Score</span><span class="p">,</span> <span class="bp">self</span><span class="p">)</span><span class="o">.</span><span class="fm">__init__</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">average</span> <span class="o">=</span> <span class="n">average</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">f1</span> <span class="o">=</span> <span class="n">FBetaScore</span><span class="p">(</span><span class="n">beta</span><span class="o">=</span><span class="mi">1</span><span class="p">,</span> <span class="n">average</span><span class="o">=</span><span class="bp">self</span><span class="o">.</span><span class="n">average</span><span class="p">)</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">f1</span><span class="o">.</span><span class="n">_name</span>

    <span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        resets counters to 0</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">f1</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y_true</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
        <span class="k">return</span> <span class="bp">self</span><span class="o">.</span><span class="n">f1</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
</code></pre></div></td></tr></table></div>
              </details>



  <div class="doc doc-children">









<div class="doc doc-object doc-function">


<h3 id="pytorch_widedeep.metrics.F1Score.reset" class="doc doc-heading">
            <span class="doc doc-object-name doc-function-name">reset</span>


<a href="#pytorch_widedeep.metrics.F1Score.reset" class="headerlink" title="Permanent link">&para;</a></h3>
<div class="doc-signature highlight"><pre><span></span><code><span class="nf">reset</span><span class="p">()</span>
</code></pre></div>

    <div class="doc doc-contents ">

        <p>resets counters to 0</p>

            <details class="quote">
              <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
              <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">338</span>
<span class="normal">339</span>
<span class="normal">340</span>
<span class="normal">341</span>
<span class="normal">342</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    resets counters to 0</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">f1</span><span class="o">.</span><span class="n">reset</span><span class="p">()</span>
</code></pre></div></td></tr></table></div>
            </details>
    </div>

</div>



  </div>

    </div>

</div>

<div class="doc doc-object doc-class">



<h2 id="pytorch_widedeep.metrics.R2Score" class="doc doc-heading">
            <span class="doc doc-object-name doc-class-name">R2Score</span>


<a href="#pytorch_widedeep.metrics.R2Score" class="headerlink" title="Permanent link">&para;</a></h2>


    <div class="doc doc-contents first">
            <p class="doc doc-class-bases">
              Bases: <code><span title="pytorch_widedeep.metrics.Metric">Metric</span></code></p>


        <p>Calculates R-Squared, the
<a href="https://en.wikipedia.org/wiki/Coefficient_of_determination">coefficient of determination</a>:</p>
<div class="arithmatex">\[
R^2 = 1 - \frac{\sum_{j=1}^n(y_j - \hat{y_j})^2}{\sum_{j=1}^n(y_j - \bar{y})^2}
\]</div>
<p>where <span class="arithmatex">\(y_j\)</span> is the ground truth, <span class="arithmatex">\(\hat{y_j}\)</span> is the predicted value and
<span class="arithmatex">\(\bar{y}\)</span> is the mean of the ground truth.</p>


<p><span class="doc-section-title">Examples:</span></p>
    <div class="highlight"><pre><span></span><code><span class="gp">&gt;&gt;&gt; </span><span class="kn">import</span> <span class="nn">torch</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="kn">from</span> <span class="nn">pytorch_widedeep.metrics</span> <span class="kn">import</span> <span class="n">R2Score</span>
<span class="gp">&gt;&gt;&gt;</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">r2</span> <span class="o">=</span> <span class="n">R2Score</span><span class="p">()</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_true</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mi">3</span><span class="p">,</span> <span class="o">-</span><span class="mf">0.5</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">7</span><span class="p">])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">y_pred</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">tensor</span><span class="p">([</span><span class="mf">2.5</span><span class="p">,</span> <span class="mf">0.0</span><span class="p">,</span> <span class="mi">2</span><span class="p">,</span> <span class="mi">8</span><span class="p">])</span><span class="o">.</span><span class="n">view</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="gp">&gt;&gt;&gt; </span><span class="n">r2</span><span class="p">(</span><span class="n">y_pred</span><span class="p">,</span> <span class="n">y_true</span><span class="p">)</span>
<span class="go">array(0.94860814)</span>
</code></pre></div>






              <details class="quote">
                <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
                <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">348</span>
<span class="normal">349</span>
<span class="normal">350</span>
<span class="normal">351</span>
<span class="normal">352</span>
<span class="normal">353</span>
<span class="normal">354</span>
<span class="normal">355</span>
<span class="normal">356</span>
<span class="normal">357</span>
<span class="normal">358</span>
<span class="normal">359</span>
<span class="normal">360</span>
<span class="normal">361</span>
<span class="normal">362</span>
<span class="normal">363</span>
<span class="normal">364</span>
<span class="normal">365</span>
<span class="normal">366</span>
<span class="normal">367</span>
<span class="normal">368</span>
<span class="normal">369</span>
<span class="normal">370</span>
<span class="normal">371</span>
<span class="normal">372</span>
<span class="normal">373</span>
<span class="normal">374</span>
<span class="normal">375</span>
<span class="normal">376</span>
<span class="normal">377</span>
<span class="normal">378</span>
<span class="normal">379</span>
<span class="normal">380</span>
<span class="normal">381</span>
<span class="normal">382</span>
<span class="normal">383</span>
<span class="normal">384</span>
<span class="normal">385</span>
<span class="normal">386</span>
<span class="normal">387</span>
<span class="normal">388</span>
<span class="normal">389</span>
<span class="normal">390</span>
<span class="normal">391</span>
<span class="normal">392</span>
<span class="normal">393</span>
<span class="normal">394</span>
<span class="normal">395</span>
<span class="normal">396</span>
<span class="normal">397</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">class</span> <span class="nc">R2Score</span><span class="p">(</span><span class="n">Metric</span><span class="p">):</span>
<span class="w">    </span><span class="sa">r</span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    Calculates R-Squared, the</span>
<span class="sd">    [coefficient of determination](https://en.wikipedia.org/wiki/Coefficient_of_determination):</span>

<span class="sd">    $$</span>
<span class="sd">    R^2 = 1 - \frac{\sum_{j=1}^n(y_j - \hat{y_j})^2}{\sum_{j=1}^n(y_j - \bar{y})^2}</span>
<span class="sd">    $$</span>

<span class="sd">    where $y_j$ is the ground truth, $\hat{y_j}$ is the predicted value and</span>
<span class="sd">    $\bar{y}$ is the mean of the ground truth.</span>

<span class="sd">    Examples</span>
<span class="sd">    --------</span>
<span class="sd">    &gt;&gt;&gt; import torch</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; from pytorch_widedeep.metrics import R2Score</span>
<span class="sd">    &gt;&gt;&gt;</span>
<span class="sd">    &gt;&gt;&gt; r2 = R2Score()</span>
<span class="sd">    &gt;&gt;&gt; y_true = torch.tensor([3, -0.5, 2, 7]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; y_pred = torch.tensor([2.5, 0.0, 2, 8]).view(-1, 1)</span>
<span class="sd">    &gt;&gt;&gt; r2(y_pred, y_true)</span>
<span class="sd">    array(0.94860814)</span>
<span class="sd">    &quot;&quot;&quot;</span>

    <span class="k">def</span> <span class="fm">__init__</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">numerator</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">denominator</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_examples</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">y_true_sum</span> <span class="o">=</span> <span class="mi">0</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">_name</span> <span class="o">=</span> <span class="s2">&quot;r2&quot;</span>

    <span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">        </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">        resets counters to 0</span>
<span class="sd">        &quot;&quot;&quot;</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">numerator</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">denominator</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">num_examples</span> <span class="o">=</span> <span class="mi">0</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">y_true_sum</span> <span class="o">=</span> <span class="mi">0</span>

    <span class="k">def</span> <span class="fm">__call__</span><span class="p">(</span><span class="bp">self</span><span class="p">,</span> <span class="n">y_pred</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">,</span> <span class="n">y_true</span><span class="p">:</span> <span class="n">Tensor</span><span class="p">)</span> <span class="o">-&gt;</span> <span class="n">np</span><span class="o">.</span><span class="n">ndarray</span><span class="p">:</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">numerator</span> <span class="o">+=</span> <span class="p">((</span><span class="n">y_pred</span> <span class="o">-</span> <span class="n">y_true</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>

        <span class="bp">self</span><span class="o">.</span><span class="n">num_examples</span> <span class="o">+=</span> <span class="n">y_true</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mi">0</span><span class="p">]</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">y_true_sum</span> <span class="o">+=</span> <span class="n">y_true</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
        <span class="n">y_true_avg</span> <span class="o">=</span> <span class="bp">self</span><span class="o">.</span><span class="n">y_true_sum</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">num_examples</span>
        <span class="bp">self</span><span class="o">.</span><span class="n">denominator</span> <span class="o">+=</span> <span class="p">((</span><span class="n">y_true</span> <span class="o">-</span> <span class="n">y_true_avg</span><span class="p">)</span> <span class="o">**</span> <span class="mi">2</span><span class="p">)</span><span class="o">.</span><span class="n">sum</span><span class="p">()</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
        <span class="k">return</span> <span class="n">np</span><span class="o">.</span><span class="n">array</span><span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">numerator</span> <span class="o">/</span> <span class="bp">self</span><span class="o">.</span><span class="n">denominator</span><span class="p">)))</span>
</code></pre></div></td></tr></table></div>
              </details>



  <div class="doc doc-children">









<!-- Reference entry for pytorch_widedeep.metrics.R2Score.reset: heading with a permalink anchor, the call signature, a one-line description, and a collapsible source listing. -->
<div class="doc doc-object doc-function">


<h3 id="pytorch_widedeep.metrics.R2Score.reset" class="doc doc-heading">
            <span class="doc doc-object-name doc-function-name">reset</span>


<a href="#pytorch_widedeep.metrics.R2Score.reset" class="headerlink" title="Permanent link">&para;</a></h3>
<div class="doc-signature highlight"><pre><span></span><code><span class="nf">reset</span><span class="p">()</span>
</code></pre></div>

    <div class="doc doc-contents ">

        <p>resets counters to 0</p>

            <!-- Collapsed by default (no "open" attribute). Shows lines 381-388 of pytorch_widedeep/metrics.py as a two-column table: line numbers on the left, highlighted code on the right. Whitespace inside the <pre> elements is significant. -->
            <details class="quote">
              <summary>Source code in <code>pytorch_widedeep/metrics.py</code></summary>
              <div class="highlight"><table class="highlighttable"><tr><td class="linenos"><div class="linenodiv"><pre><span></span><span class="normal">381</span>
<span class="normal">382</span>
<span class="normal">383</span>
<span class="normal">384</span>
<span class="normal">385</span>
<span class="normal">386</span>
<span class="normal">387</span>
<span class="normal">388</span></pre></div></td><td class="code"><div><pre><span></span><code><span class="k">def</span> <span class="nf">reset</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w">    </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd">    resets counters to 0</span>
<span class="sd">    &quot;&quot;&quot;</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">numerator</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">denominator</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">num_examples</span> <span class="o">=</span> <span class="mi">0</span>
    <span class="bp">self</span><span class="o">.</span><span class="n">y_true_sum</span> <span class="o">=</span> <span class="mi">0</span>
</code></pre></div></td></tr></table></div>
            </details>
    </div>

</div>



  </div>

    </div>

</div>








  <aside class="md-source-file">
    <!-- Page metadata: the people who contributed to this page. -->
    <span class="md-source-file__fact">
      <span class="md-icon" title="Contributors">
        <!-- Decorative icon; the list below carries its own label, so hide the graphic from assistive technology. -->
        <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" aria-hidden="true"><path d="M12 5.5A3.5 3.5 0 0 1 15.5 9a3.5 3.5 0 0 1-3.5 3.5A3.5 3.5 0 0 1 8.5 9 3.5 3.5 0 0 1 12 5.5M5 8c.56 0 1.08.15 1.53.42-.15 1.43.27 2.85 1.13 3.96C7.16 13.34 6.16 14 5 14a3 3 0 0 1-3-3 3 3 0 0 1 3-3m14 0a3 3 0 0 1 3 3 3 3 0 0 1-3 3c-1.16 0-2.16-.66-2.66-1.62a5.54 5.54 0 0 0 1.13-3.96c.45-.27.97-.42 1.53-.42M5.5 18.25c0-2.07 2.91-3.75 6.5-3.75s6.5 1.68 6.5 3.75V20h-13zM0 20v-1.5c0-1.39 1.89-2.56 4.45-2.9-.59.68-.95 1.62-.95 2.65V20zm24 0h-3.5v-1.75c0-1.03-.36-1.97-.95-2.65 2.56.34 4.45 1.51 4.45 2.9z"/></svg>
      </span>
      <nav aria-label="Contributors">
        <!-- The address recorded for this contributor used a machine-local ".local" hostname, which cannot receive mail, so the name is shown without a mailto link. -->
        <span>Javier Rodriguez Zaurin</span>,
        <a href="mailto:mulinka.pavol@gmail.com">Pavol Mulinka</a>
      </nav>
    </span>
  </aside>





                
              </article>
            </div>
          
          
<script>
  // Look up the element whose id matches the URL fragment (without the leading "#").
  // If it exists and has a name, set its checked state to whether that name
  // starts with "__tabbed_". `target` is deliberately left as a global var.
  var target = document.getElementById(location.hash.slice(1));
  if (target && target.name) {
    target.checked = target.name.startsWith("__tabbed_");
  }
</script>
        </div>
        
      </main>
      
        <footer class="md-footer">
          <!-- Site footer: copyright notice, theme credit, and social links. -->
          <div class="md-footer-meta md-typeset">
            <div class="md-footer-meta__inner md-grid">
              <div class="md-copyright">
                <div class="md-copyright__highlight">
                  Javier Zaurin and Pavol Mulinka
                </div>
                Made with
                <a href="https://squidfunk.github.io/mkdocs-material/" target="_blank" rel="noopener">
                  Material for MkDocs
                </a>
              </div>
              <div class="md-social">
                <!-- Icon-only link: aria-label supplies the accessible name (title alone is not reliably announced); the SVG is decorative and hidden from assistive technology. -->
                <a href="https://jrzaurin.medium.com/" target="_blank" rel="noopener" title="jrzaurin.medium.com" class="md-social__link" aria-label="Medium: jrzaurin.medium.com (opens in a new tab)">
                  <svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 640 512" aria-hidden="true"><!--! Font Awesome Free 6.6.0 by @fontawesome - https://fontawesome.com License - https://fontawesome.com/license/free (Icons: CC BY 4.0, Fonts: SIL OFL 1.1, Code: MIT License) Copyright 2024 Fonticons, Inc.--><path d="M180.5 74.262C80.813 74.262 0 155.633 0 256s80.819 181.738 180.5 181.738S361 356.373 361 256 280.191 74.262 180.5 74.262m288.25 10.646c-49.845 0-90.245 76.619-90.245 171.095s40.406 171.1 90.251 171.1 90.251-76.619 90.251-171.1H559c0-94.503-40.4-171.095-90.248-171.095Zm139.506 17.821c-17.526 0-31.735 68.628-31.735 153.274s14.2 153.274 31.735 153.274S640 340.631 640 256c0-84.649-14.215-153.271-31.742-153.271Z"/></svg>
                </a>
              </div>
            </div>
          </div>
        </footer>
      
    </div>
    <!-- Empty dialog placeholder shipped with the page. NOTE(review): presumably the theme bundle fills .md-dialog__inner at runtime via the data-md-component hook; confirm against the bundle. -->
    <div class="md-dialog" data-md-component="dialog">
      <div class="md-dialog__inner md-typeset"></div>
    </div>
    
    
    <!-- Page configuration as inert JSON (base path, enabled features, search worker URL, UI strings). NOTE(review): presumably read by the theme bundle loaded next; confirm against the bundle. -->
    <script id="__config" type="application/json">{"base": "..", "features": ["navigation.tabs", "navigation.tabs.sticky", "navigation.indexes", "navigation.expand", "toc.integrate"], "search": "../assets/javascripts/workers/search.6ce7567c.min.js", "translations": {"clipboard.copied": "Copied to clipboard", "clipboard.copy": "Copy to clipboard", "search.result.more.one": "1 more on this page", "search.result.more.other": "# more on this page", "search.result.none": "No matching documents", "search.result.one": "1 matching document", "search.result.other": "# matching documents", "search.result.placeholder": "Type to start searching", "search.result.term.missing": "Missing", "select.version": "Select version"}}</script>

    <!-- Scripts run in document order: theme bundle, site extras, local MathJax setup, ES6 polyfill, then MathJax itself. -->
    <script src="../assets/javascripts/bundle.83f73b43.min.js"></script>
    <script src="../stylesheets/extra.js"></script>
    <script src="../javascripts/mathjax.js"></script>
    <!-- Security: the polyfill.io domain changed ownership and served malicious code in 2024. Load the same v3 endpoint from Cloudflare's cdnjs mirror instead. -->
    <script src="https://cdnjs.cloudflare.com/polyfill/v3/polyfill.min.js?features=es6"></script>
    <script src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
      
    
  </body>
</html>